import math
import os
import sys

import dask.dataframe as da
import pandas as pd
from datasets import load_dataset
# Reshard the "train" split of the dataset at sys.argv[1] into Parquet files
# named "<path>_split/train-XXXXX-of-YYYYY.parquet", SHARD_SIZE rows each.

# Number of rows written to each Parquet shard.
SHARD_SIZE = 100000

if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} DATASET_PATH")
path = sys.argv[1]
dataset = load_dataset(path, split="train")

# Shards go to a sibling directory "<path>_split". Strip trailing slashes so
# "data/foo/" yields "data/foo_split" rather than "data/foo/_split".
out_dir = f"{path.rstrip('/')}_split"
# DataFrame.to_parquet does not create missing parent directories.
os.makedirs(out_dir, exist_ok=True)

num_rows = len(dataset)
all_shards = math.ceil(num_rows / SHARD_SIZE)
# Zero-padded to 5 digits (e.g. train-00000-of-00003.parquet). Unlike the
# old "00"+str(i + 1000)[1:] trick, this stays unique past 1000 shards.
all_id = f"{all_shards:05d}"
for i in range(all_shards):
    idx = f"{i:05d}"
    # Slicing a datasets.Dataset returns a dict of column lists for just this
    # range, so only one shard's rows are materialized at a time instead of
    # the whole dataset.
    shard = pd.DataFrame(dataset[i * SHARD_SIZE:(i + 1) * SHARD_SIZE])
    print(f"saving shard {idx}...")
    shard.to_parquet(f"{out_dir}/train-{idx}-of-{all_id}.parquet")
    
#dataset.save_to_disk(f"{path}_split")
